{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ISubpr_SSsiM"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "3jTMb1dySr3V"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6DWfyNThSziV"
      },
      "source": [
        "# Introduction to modules, layers, and models\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/intro_to_modules\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/intro_to_modules.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/intro_to_modules.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/intro_to_modules.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v0DdlfacAdTZ"
      },
      "source": [
        "To do machine learning in TensorFlow, you are likely to need to define, save, and restore a model.\n",
        "\n",
        "A model is, abstractly: \n",
        "\n",
        "* A function that computes something on tensors (a **forward pass**)\n",
        "* Some variables that can be updated in response to training\n",
        "\n",
        "In this guide, you will go below the surface of Keras to see how TensorFlow models are defined. This looks at how TensorFlow collects variables and models, as well as how they are saved and restored.\n",
        "\n",
        "Note: If you instead want to immediately get started with Keras, please see [the collection of Keras guides](./keras/).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VSa6ayJmfZxZ"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "goZwOXp_xyQj"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import keras\n",
        "from datetime import datetime\n",
        "\n",
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yt5HEbsYAbw1"
      },
      "source": [
        "## TensorFlow Modules\n",
        "\n",
        "Most models are made of layers. Layers are functions with a known mathematical structure that can be reused and have trainable variables.  In TensorFlow, most high-level implementations of layers and models, such as Keras or [Sonnet](https://github.com/deepmind/sonnet), are built on the same foundational class: `tf.Module`.\n",
        "\n",
        "### Building Modules\n",
        "\n",
        "Here's an example of a very simple `tf.Module` that operates on a scalar tensor:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "alhYPVEtAiSy"
      },
      "outputs": [],
      "source": [
        "class SimpleModule(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.a_variable = tf.Variable(5.0, name=\"train_me\")\n",
        "    self.non_trainable_variable = tf.Variable(5.0, trainable=False, name=\"do_not_train_me\")\n",
        "  def __call__(self, x):\n",
        "    return self.a_variable * x + self.non_trainable_variable\n",
        "\n",
        "simple_module = SimpleModule(name=\"simple\")\n",
        "\n",
        "simple_module(tf.constant(5.0))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JwMc_zu5Ant8"
      },
      "source": [
        "Modules and, by extension, layers are deep-learning terminology for \"objects\": they have internal state, and methods that use that state.\n",
        "\n",
        "There is nothing special about `__call__` except to act like a [Python callable](https://stackoverflow.com/questions/111234/what-is-a-callable); you can invoke your models with whatever functions you wish.\n",
        "\n",
        "You can set the trainability of variables on and off for any reason, including freezing layers and variables during fine-tuning.\n",
        "\n",
        "Note: `tf.Module` is the base class for both `tf.keras.layers.Layer` and `tf.keras.Model`, so everything you come across here also applies in Keras.  For historical compatibility reasons Keras layers do not collect variables from modules, so your models should use only modules or only Keras layers.  However, the methods shown below for inspecting variables are the same in either case.\n",
        "\n",
        "By subclassing `tf.Module`, any `tf.Variable` or `tf.Module` instances assigned to this object's properties are automatically collected.  This allows you to save and load variables, and also create collections of `tf.Module`s."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CyzYy4A_CbVf"
      },
      "outputs": [],
      "source": [
        "# All trainable variables\n",
        "print(\"trainable variables:\", simple_module.trainable_variables)\n",
        "# Every variable\n",
        "print(\"all variables:\", simple_module.variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nuSFrRUNCaaW"
      },
      "source": [
        "This is an example of a two-layer linear layer model made out of modules.\n",
        "\n",
        "First a dense (linear) layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Efb2p2bzAn-V"
      },
      "outputs": [],
      "source": [
        "class Dense(tf.Module):\n",
        "  def __init__(self, in_features, out_features, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.w = tf.Variable(\n",
        "      tf.random.normal([in_features, out_features]), name='w')\n",
        "    self.b = tf.Variable(tf.zeros([out_features]), name='b')\n",
        "  def __call__(self, x):\n",
        "    y = tf.matmul(x, self.w) + self.b\n",
        "    return tf.nn.relu(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bAhMuC-UpnhX"
      },
      "source": [
        "And then the complete model, which makes two layer instances and applies them:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QQ7qQf-DFw74"
      },
      "outputs": [],
      "source": [
        "class SequentialModule(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "\n",
        "    self.dense_1 = Dense(in_features=3, out_features=3)\n",
        "    self.dense_2 = Dense(in_features=3, out_features=2)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    x = self.dense_1(x)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "# You have made a model!\n",
        "my_model = SequentialModule(name=\"the_model\")\n",
        "\n",
        "# Call it, with random results\n",
        "print(\"Model results:\", my_model(tf.constant([[2.0, 2.0, 2.0]])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d1oUzasJHHXf"
      },
      "source": [
        "`tf.Module` instances will automatically collect, recursively, any `tf.Variable` or `tf.Module` instances assigned to it. This allows you to manage collections of `tf.Module`s with a single model instance, and save and load whole models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JLFA5_PEGb6C"
      },
      "outputs": [],
      "source": [
        "print(\"Submodules:\", my_model.submodules)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6lzoB8pcRN12"
      },
      "outputs": [],
      "source": [
        "for var in my_model.variables:\n",
        "  print(var, \"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hoaxL3zzm0vK"
      },
      "source": [
        "### Waiting to create variables\n",
        "\n",
        "You may have noticed here that you have to define both input and output sizes to the layer.  This is so the `w` variable has a known shape and can be allocated.\n",
        "\n",
        "By deferring variable creation to the first time the module is called with a specific input shape, you do not need specify the input size up front."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XsGCLFXlnPum"
      },
      "outputs": [],
      "source": [
        "class FlexibleDenseModule(tf.Module):\n",
        "  # Note: No need for `in_features`\n",
        "  def __init__(self, out_features, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.is_built = False\n",
        "    self.out_features = out_features\n",
        "\n",
        "  def __call__(self, x):\n",
        "    # Create variables on first call.\n",
        "    if not self.is_built:\n",
        "      self.w = tf.Variable(\n",
        "        tf.random.normal([x.shape[-1], self.out_features]), name='w')\n",
        "      self.b = tf.Variable(tf.zeros([self.out_features]), name='b')\n",
        "      self.is_built = True\n",
        "\n",
        "    y = tf.matmul(x, self.w) + self.b\n",
        "    return tf.nn.relu(y)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8bjOWax9LOkP"
      },
      "outputs": [],
      "source": [
        "# Used in a module\n",
        "class MySequentialModule(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "\n",
        "    self.dense_1 = FlexibleDenseModule(out_features=3)\n",
        "    self.dense_2 = FlexibleDenseModule(out_features=2)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    x = self.dense_1(x)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "my_model = MySequentialModule(name=\"the_model\")\n",
        "print(\"Model results:\", my_model(tf.constant([[2.0, 2.0, 2.0]])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "49JfbhVrpOLH"
      },
      "source": [
        "This flexibility is why TensorFlow layers often only need to specify the shape of their outputs, such as in `tf.keras.layers.Dense`, rather than both  the input and output size."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JOLVVBT8J_dl"
      },
      "source": [
        "### Saving weights\n",
        "\n",
        "You can save a `tf.Module` as both a [checkpoint](./checkpoint.ipynb) and a [SavedModel](./saved_model.ipynb).\n",
        "\n",
        "Checkpoints are just the weights (that is, the values of the set of variables inside the module and its submodules):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pHXKRDk7OLHA"
      },
      "outputs": [],
      "source": [
        "chkp_path = \"my_checkpoint\"\n",
        "checkpoint = tf.train.Checkpoint(model=my_model)\n",
        "checkpoint.write(chkp_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WXOPMBR4T4ZR"
      },
      "source": [
        "Checkpoints consist of two kinds of files: the data itself and an index file for metadata. The index file keeps track of what is actually saved and the numbering of checkpoints, while the checkpoint data contains the variable values and their attribute lookup paths."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jBV3fprlTWqJ"
      },
      "outputs": [],
      "source": [
        "!ls my_checkpoint*"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CowCuBTvXgUu"
      },
      "source": [
        "You can look inside a checkpoint to be sure the whole collection of variables is saved, sorted by the Python object that contains them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o2QAdfpvS8tB"
      },
      "outputs": [],
      "source": [
        "tf.train.list_variables(chkp_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4eGaNiQWcK4j"
      },
      "source": [
        "During distributed (multi-machine) training they can be sharded,  which is why they are numbered (e.g., '00000-of-00001').  In this case, though, there is only one shard.\n",
        "\n",
        "When you load models back in, you overwrite the values in your Python object."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UV8rdDzcwVVg"
      },
      "outputs": [],
      "source": [
        "new_model = MySequentialModule()\n",
        "new_checkpoint = tf.train.Checkpoint(model=new_model)\n",
        "new_checkpoint.restore(\"my_checkpoint\")\n",
        "\n",
        "# Should be the same result as above\n",
        "new_model(tf.constant([[2.0, 2.0, 2.0]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BnPwDRwamdfq"
      },
      "source": [
        "Note: As checkpoints are at the heart of long training workflows `tf.checkpoint.CheckpointManager` is a helper class that makes checkpoint management much easier. Refer to the [Training checkpoints guide](./checkpoint.ipynb) for more details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pSZebVuWxDXu"
      },
      "source": [
        "### Saving functions\n",
        "\n",
        "TensorFlow can run models without the original Python objects, as demonstrated by [TensorFlow Serving](https://tensorflow.org/tfx) and [TensorFlow Lite](https://tensorflow.org/lite), even when you download a trained model from [TensorFlow Hub](https://tensorflow.org/hub).\n",
        "\n",
        "TensorFlow needs to know how to do the computations described in Python, but **without the original code**. To do this, you can make a **graph**, which is described in the [Introduction to graphs and functions guide](./intro_to_graphs.ipynb).\n",
        "\n",
        "This graph contains operations, or *ops*, that implement the function.\n",
        "\n",
        "You can define a graph in the model above by adding the `@tf.function` decorator to indicate that this code should run as a graph."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WQTvkapUh7lk"
      },
      "outputs": [],
      "source": [
        "class MySequentialModule(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "\n",
        "    self.dense_1 = Dense(in_features=3, out_features=3)\n",
        "    self.dense_2 = Dense(in_features=3, out_features=2)\n",
        "\n",
        "  @tf.function\n",
        "  def __call__(self, x):\n",
        "    x = self.dense_1(x)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "# You have made a model with a graph!\n",
        "my_model = MySequentialModule(name=\"the_model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hW66YXBziLo9"
      },
      "source": [
        "The module you have made works exactly the same as before.  Each unique signature passed into the function creates a separate graph. Check the [Introduction to graphs and functions guide](./intro_to_graphs.ipynb) for details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H5zUfti3iR52"
      },
      "outputs": [],
      "source": [
        "print(my_model([[2.0, 2.0, 2.0]]))\n",
        "print(my_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lbGlU1kgyDo7"
      },
      "source": [
        "You can visualize the graph by tracing it within a TensorBoard summary."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zmy-T67zhp-S"
      },
      "outputs": [],
      "source": [
        "# Set up logging.\n",
        "stamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
        "logdir = \"logs/func/%s\" % stamp\n",
        "writer = tf.summary.create_file_writer(logdir)\n",
        "\n",
        "# Create a new model to get a fresh trace\n",
        "# Otherwise the summary will not see the graph.\n",
        "new_model = MySequentialModule()\n",
        "\n",
        "# Bracket the function call with\n",
        "# tf.summary.trace_on() and tf.summary.trace_export().\n",
        "tf.summary.trace_on(graph=True)\n",
        "tf.profiler.experimental.start(logdir)\n",
        "# Call only one tf.function when tracing.\n",
        "z = print(new_model(tf.constant([[2.0, 2.0, 2.0]])))\n",
        "with writer.as_default():\n",
        "  tf.summary.trace_export(\n",
        "      name=\"my_func_trace\",\n",
        "      step=0,\n",
        "      profiler_outdir=logdir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gz4lwNZ9hR79"
      },
      "source": [
        "Launch TensorBoard to view the resulting trace:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V4MXDbgBnkJu"
      },
      "outputs": [],
      "source": [
        "#docs_infra: no_execute\n",
        "%tensorboard --logdir logs/func"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gjattu0AhYUl"
      },
      "source": [
        "![A screenshot of the graph in TensorBoard](images/tensorboard_graph.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SQu3TVZecmL7"
      },
      "source": [
        "### Creating a `SavedModel`\n",
        "\n",
        "The recommended way of sharing completely trained models is to use `SavedModel`.  `SavedModel` contains both a collection of functions and a collection of weights. \n",
        "\n",
        "You can save the model you have just trained as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Awv_Tw__WK7a"
      },
      "outputs": [],
      "source": [
        "tf.saved_model.save(my_model, \"the_saved_model\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SXv3mEKsefGj"
      },
      "outputs": [],
      "source": [
        "# Inspect the SavedModel in the directory\n",
        "!ls -l the_saved_model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vQQ3hEvHYdoR"
      },
      "outputs": [],
      "source": [
        "# The variables/ directory contains a checkpoint of the variables \n",
        "!ls -l the_saved_model/variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xBqPop7ZesBU"
      },
      "source": [
        "The `saved_model.pb` file is a [protocol buffer](https://developers.google.com/protocol-buffers) describing the functional `tf.Graph`.\n",
        "\n",
        "Models and layers can be loaded from this representation without actually making an instance of the class that created it.  This is desired in situations where you do not have (or want) a Python interpreter, such as serving at scale or on an edge device, or in situations where the original Python code is not available or practical to use.\n",
        "\n",
        "You can load the model as new object:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zRFcA5wIefv4"
      },
      "outputs": [],
      "source": [
        "new_model = tf.saved_model.load(\"the_saved_model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-9EF3mT7i3qN"
      },
      "source": [
        "`new_model`, created from loading a saved model, is an internal TensorFlow user object without any of the class knowledge. It is not of type `SequentialModule`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EC_eQj7yi54G"
      },
      "outputs": [],
      "source": [
        "isinstance(new_model, SequentialModule)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-OrOX1zxiyhR"
      },
      "source": [
        "This new model works on the already-defined input signatures. You can't add more signatures to a model restored like this."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_23BYYBWfKnc"
      },
      "outputs": [],
      "source": [
        "print(my_model([[2.0, 2.0, 2.0]]))\n",
        "print(my_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qSFhoMtTjSR6"
      },
      "source": [
        "Thus, using `SavedModel`, you are able to save TensorFlow weights and graphs using `tf.Module`, and then load them again."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rb9IdN7hlUZK"
      },
      "source": [
        "## Keras models and layers\n",
        "\n",
        "Note that up until this point, there is no mention of Keras. You can build your own high-level API on top of `tf.Module`, and people have.  \n",
        "\n",
        "In this section, you will examine how Keras uses `tf.Module`.  A complete user guide to Keras models can be found in the [Keras guide](https://www.tensorflow.org/guide/keras/sequential_model).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ds08u3touwe4t"
      },
      "source": [
        "Keras layers and models have a lot more extra features including:\n",
        "\n",
        "* Optional losses\n",
        "* Support for [metrics](https://keras.io/api/layers/base_layer/#add_metric-method)\n",
        "* Built-in support for an optional `training` argument to differentiate between training and inference use\n",
        "* Saving and restoring python objects instead of just black-box functions\n",
        "* `get_config` and `from_config` methods that allow you to accurately store configurations to allow model cloning in Python\n",
        "\n",
        "These features allow for far more complex models through subclassing, such as a custom GAN or a Variational AutoEncoder (VAE) model. Read about them in the [full guide](./keras/custom_layers_and_models.ipynb) to custom layers and models.\n",
        "\n",
        "Keras models also come with extra functionality that makes them easy to train, evaluate, load, save, and even train on multiple machines."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uigsVGPreE-D"
      },
      "source": [
        "### Keras layers\n",
        "\n",
        "`tf.keras.layers.Layer` is the base class of all Keras layers, and it inherits from `tf.Module`.\n",
        "\n",
        "You can convert a module into a Keras layer just by swapping out the parent and then changing `__call__` to `call`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "88YOGquhnQRd"
      },
      "outputs": [],
      "source": [
        "class MyDense(tf.keras.layers.Layer):\n",
        "  # Adding **kwargs to support base Keras layer arguments\n",
        "  def __init__(self, in_features, out_features, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "\n",
        "    # This will soon move to the build step; see below\n",
        "    self.w = tf.Variable(\n",
        "      tf.random.normal([in_features, out_features]), name='w')\n",
        "    self.b = tf.Variable(tf.zeros([out_features]), name='b')\n",
        "  def call(self, x):\n",
        "    y = tf.matmul(x, self.w) + self.b\n",
        "    return tf.nn.relu(y)\n",
        "\n",
        "simple_layer = MyDense(name=\"simple\", in_features=3, out_features=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nYGmAsPrws--"
      },
      "source": [
        "Keras layers have their own `__call__` that does some bookkeeping described in the next section and then calls `call()`. You should notice no change in functionality."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nIqE8wOznYKG"
      },
      "outputs": [],
      "source": [
        "simple_layer([[2.0, 2.0, 2.0]])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tmN5vb1K18U1"
      },
      "source": [
        "### The `build` step\n",
        "\n",
        "As noted, it's convenient in many cases to wait to create variables until you are sure of the input shape.\n",
        "\n",
        "Keras layers come with an extra lifecycle step that allows you more flexibility in how you define your layers. This is defined in the `build` function.\n",
        "\n",
        "`build` is called exactly once, and it is called with the shape of the input. It's usually used to create variables (weights).\n",
        "\n",
        "You can rewrite `MyDense` layer above to be flexible to the size of its inputs:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4YTfrlgdsURp"
      },
      "outputs": [],
      "source": [
        "class FlexibleDense(tf.keras.layers.Layer):\n",
        "  # Note the added `**kwargs`, as Keras supports many arguments\n",
        "  def __init__(self, out_features, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "    self.out_features = out_features\n",
        "\n",
        "  def build(self, input_shape):  # Create the state of the layer (weights)\n",
        "    self.w = tf.Variable(\n",
        "      tf.random.normal([input_shape[-1], self.out_features]), name='w')\n",
        "    self.b = tf.Variable(tf.zeros([self.out_features]), name='b')\n",
        "\n",
        "  def call(self, inputs):  # Defines the computation from inputs to outputs\n",
        "    return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "# Create the instance of the layer\n",
        "flexible_dense = FlexibleDense(out_features=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Koc_uSqt2PRh"
      },
      "source": [
        "At this point, the model has not been built, so there are no variables:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DgyTyUD32Ln4"
      },
      "outputs": [],
      "source": [
        "flexible_dense.variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-KdamIVl2W8Y"
      },
      "source": [
        "Calling the function allocates appropriately-sized variables:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IkLyEx7uAoTK"
      },
      "outputs": [],
      "source": [
        "# Call it, with predictably random results\n",
        "print(\"Model results:\", flexible_dense(tf.constant([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Swofpkrd2YDd"
      },
      "outputs": [],
      "source": [
        "flexible_dense.variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7PuNUnf0OIpF"
      },
      "source": [
        "Since `build` is only called once, inputs will be rejected if the input shape is not compatible with the layer's variables:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "caYWDrHSAy_j"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  print(\"Model results:\", flexible_dense(tf.constant([[2.0, 2.0, 2.0, 2.0]])))\n",
        "except tf.errors.InvalidArgumentError as e:\n",
        "  print(\"Failed:\", e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L2kds2IHw2KD"
      },
      "source": [
        "### Keras models\n",
        "\n",
        "You can define your model as nested Keras layers.\n",
        "\n",
        "However, Keras also provides a full-featured model class called `tf.keras.Model`. It inherits from `tf.keras.layers.Layer`, so a Keras model can be used and nested in the same way as Keras layers. Keras models come with extra functionality that makes them easy to train, evaluate, load, save, and even train on multiple machines.\n",
        "\n",
        "You can define the `SequentialModule` from above with nearly identical code, again converting `__call__` to `call()` and changing the parent:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hqjo1DiyrHrn"
      },
      "outputs": [],
      "source": [
        "@keras.saving.register_keras_serializable()\n",
        "class MySequentialModel(tf.keras.Model):\n",
        "  def __init__(self, name=None, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "\n",
        "    self.dense_1 = FlexibleDense(out_features=3)\n",
        "    self.dense_2 = FlexibleDense(out_features=2)\n",
        "  def call(self, x):\n",
        "    x = self.dense_1(x)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "# You have made a Keras model!\n",
        "my_sequential_model = MySequentialModel(name=\"the_model\")\n",
        "\n",
        "# Call it on a tensor, with random results\n",
        "print(\"Model results:\", my_sequential_model(tf.constant([[2.0, 2.0, 2.0]])))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8i-CR_h2xw3z"
      },
      "source": [
        "All the same features are available, including tracking variables and submodules.\n",
        "\n",
        "Note: A raw `tf.Module` nested inside a Keras layer or model will not have its variables collected for training or saving. Instead, nest Keras layers inside Keras layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hdLQFNdMsOz1"
      },
      "outputs": [],
      "source": [
        "my_sequential_model.variables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JjVAMrAJsQ7G"
      },
      "outputs": [],
      "source": [
        "my_sequential_model.submodules"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FhP8EItC4oac"
      },
      "source": [
        "Subclassing `tf.keras.Model` is a very Pythonic approach to building TensorFlow models. If you are migrating models from other frameworks, this approach can be very straightforward.\n",
        "\n",
        "If you are constructing models that are simple assemblages of existing layers and inputs, you can save time and space by using the [functional API](./keras/functional.ipynb), which comes with additional features around model reconstruction and architecture.\n",
        "\n",
        "Here is the same model with the functional API:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jJiZZiJ0fyqQ"
      },
      "outputs": [],
      "source": [
        "inputs = tf.keras.Input(shape=[3,])\n",
        "\n",
        "x = FlexibleDense(3)(inputs)\n",
        "x = FlexibleDense(2)(x)\n",
        "\n",
        "my_functional_model = tf.keras.Model(inputs=inputs, outputs=x)\n",
        "\n",
        "my_functional_model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kg-xAZw5gaG6"
      },
      "outputs": [],
      "source": [
        "my_functional_model(tf.constant([[2.0, 2.0, 2.0]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s_BK9XH5q9cq"
      },
      "source": [
        "The major difference here is that the input shape is specified up front as part of the functional construction process. The `shape` argument to `tf.keras.Input` does not have to be completely specified in this case; you can leave some dimensions as `None`.\n",
        "\n",
        "Note: You do not need to specify `input_shape` or an `InputLayer` in a subclassed model; these arguments and layers will be ignored."
      ]
    },
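    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of a partially specified input shape, the following cell leaves one dimension as `None`. It uses the built-in `tf.keras.layers.Dense` rather than `FlexibleDense` so that it is self-contained. The names `partial_inputs`, `partial_outputs`, and `partial_shape_model` are hypothetical and chosen only for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical example: only the last dimension (3) is fixed.\n",
        "# The middle dimension is left as None, so its size may vary per call.\n",
        "partial_inputs = tf.keras.Input(shape=[None, 3])\n",
        "partial_outputs = tf.keras.layers.Dense(2)(partial_inputs)\n",
        "partial_shape_model = tf.keras.Model(inputs=partial_inputs, outputs=partial_outputs)\n",
        "\n",
        "# The same model accepts inputs whose unspecified dimension differs.\n",
        "short_result = partial_shape_model(tf.zeros([1, 4, 3]))\n",
        "long_result = partial_shape_model(tf.zeros([1, 7, 3]))\n",
        "print(short_result.shape, long_result.shape)"
      ]
    },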
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qI9aXLnaHEFF"
      },
      "source": [
        "### Saving Keras models\n",
        "\n",
        "Keras models have their own specialized zip archive saving format, marked by the `.keras` extension. When calling `tf.keras.Model.save`, add a `.keras` extension to the filename. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SAz-KVZlzAJu"
      },
      "outputs": [],
      "source": [
        "my_sequential_model.save(\"exname_of_file.keras\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C2urAeR-omns"
      },
      "source": [
        "Just as easily, they can be loaded back in:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wj5DW-LCopry"
      },
      "outputs": [],
      "source": [
        "reconstructed_model = tf.keras.models.load_model(\"exname_of_file.keras\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EA7P_MNvpviZ"
      },
      "source": [
        "Keras zip archives (`.keras` files) also save metric, loss, and optimizer states.\n",
        "\n",
        "This reconstructed model can be used and will produce the same result when called on the same data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P_wGfQo5pe6T"
      },
      "outputs": [],
      "source": [
        "reconstructed_model(tf.constant([[2.0, 2.0, 2.0]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "seLIUG2354s"
      },
      "source": [
        "### Checkpointing Keras models\n",
        "\n",
        "Keras models can also be checkpointed, and the process looks the same as it does for `tf.Module`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xKyjlkceqjwD"
      },
      "source": [
        "There is more to know about saving and serialization of Keras models, including how to provide configuration methods so that custom layers support these features. Check out the [guide to saving and serialization](https://www.tensorflow.org/guide/keras/save_and_serialize)."
      ]
    },
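    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what a configuration method can look like, the cell below defines a hypothetical `ConfigurableDense` layer (not part of this guide's earlier code) whose `get_config` records its constructor argument. It assumes the default `from_config` behavior of passing the config back to `__init__`; the linked guide remains the authoritative reference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical layer, used only to illustrate a configuration method.\n",
        "class ConfigurableDense(tf.keras.layers.Layer):\n",
        "  def __init__(self, out_features, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "    self.out_features = out_features\n",
        "\n",
        "  def build(self, input_shape):\n",
        "    # Create the weights once the input size is known.\n",
        "    self.w = self.add_weight(\n",
        "        shape=(input_shape[-1], self.out_features),\n",
        "        initializer=\"random_normal\", name=\"w\")\n",
        "    self.b = self.add_weight(\n",
        "        shape=(self.out_features,), initializer=\"zeros\", name=\"b\")\n",
        "\n",
        "  def call(self, inputs):\n",
        "    return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "  def get_config(self):\n",
        "    # Start from the base config and add the constructor argument.\n",
        "    config = super().get_config()\n",
        "    config.update({\"out_features\": self.out_features})\n",
        "    return config\n",
        "\n",
        "original_layer = ConfigurableDense(4, name=\"configurable\")\n",
        "\n",
        "# Rebuild an equivalent (untrained) layer from the config alone.\n",
        "rebuilt_layer = ConfigurableDense.from_config(original_layer.get_config())\n",
        "rebuilt_output = rebuilt_layer(tf.zeros([1, 3]))\n",
        "print(rebuilt_layer.name, rebuilt_layer.out_features, rebuilt_output.shape)"
      ]
    },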
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kcdMMPYv7Krz"
      },
      "source": [
        "# What's next\n",
        "\n",
        "If you want to know more details about Keras, you can follow the existing Keras guides [here](./keras/).\n",
        "\n",
        "Another example of a high-level API built on `tf.Module` is Sonnet from DeepMind, which is covered on [their site](https://github.com/deepmind/sonnet)."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "ISubpr_SSsiM"
      ],
      "name": "intro_to_modules.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
